{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Better ML Engineering with ML Metadata\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tfx/tutorials/mlmd_tutorial\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/mlmd_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tfx/blob/master/docs/tutorials/mlmd_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/tools/templates/notebook.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "Suppose you set up a production ML pipeline to classify pictures of iris flowers. The pipeline ingests your training data, trains and evaluates a model, and pushes it to production.\n",
        "\n",
        "However, when you later use this model with a larger dataset that contains images of other kinds of flowers, the model does not behave as expected and starts classifying roses and lilies as types of irises.\n",
        "\n",
        "At this point, you are interested in knowing:\n",
        "\n",
        "* What is the most efficient way to debug the model when the only available artifact is the model in production?\n",
        "* Which training dataset was used to train the model?\n",
        "* Which training run led to this erroneous model?\n",
        "* Where are the model evaluation results?\n",
        "* Where should you begin debugging?\n",
        "\n",
        "[ML Metadata (MLMD)](https://github.com/google/ml-metadata) is a library that uses the metadata associated with ML models to help you answer these questions and more. A helpful analogy is to think of this metadata as the equivalent of logging in software development. MLMD enables you to reliably track the artifacts and lineage associated with the various components of your ML pipeline.\n",
        "\n",
        "In this tutorial, you set up a TFX pipeline to create a model that classifies Iris flowers into three species (Iris setosa, Iris virginica, and Iris versicolor) based on the length and width measurements of their petals and sepals. You then use MLMD to track the lineage of pipeline components."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3rGF8hLibz6p"
      },
      "source": [
        "## TFX Pipelines in Colab\n",
        "\n",
        "Colab is a lightweight development environment that differs significantly from a production environment. In production, pipeline components such as data ingestion, transformation, model training, and run histories may be spread across multiple distributed systems. For this tutorial, be aware that orchestration and metadata storage differ significantly from production: both are handled locally within Colab. Learn more about TFX in Colab [here](https://www.tensorflow.org/tfx/tutorials/tfx/components_keras#background).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Import all required libraries."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mQV-Cget1S8t"
      },
      "source": [
        "### Install and import TFX"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "82jOhrcA36YA"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet tfx==0.23.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OD2cRhwM3ez2"
      },
      "source": [
        "You must restart the Colab runtime after installing TFX. Select **Runtime > Restart runtime** from the Colab menu.\n",
        "\n",
        "Do not proceed with the rest of this tutorial without first restarting the runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ohOztGn2wc1z"
      },
      "source": [
        "### Import other libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "import csv\n",
        "import json\n",
        "import os\n",
        "import requests\n",
        "import tempfile\n",
        "import urllib.request\n",
        "import pprint\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "pp = pprint.PrettyPrinter()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KY8ROY_x1wZ2"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tfx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hWCiTxFo32jk"
      },
      "source": [
        "Import [TFX component](https://www.tensorflow.org/tfx/tutorials/tfx/components_keras) classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qgXWGGmK1rhP"
      },
      "outputs": [],
      "source": [
        "from tfx.components.evaluator.component import Evaluator\n",
        "from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen\n",
        "from tfx.components.pusher.component import Pusher\n",
        "from tfx.components.schema_gen.component import SchemaGen\n",
        "from tfx.components.statistics_gen.component import StatisticsGen\n",
        "from tfx.components.trainer.component import Trainer\n",
        "from tfx.components.base import executor_spec\n",
        "from tfx.components.trainer.executor import GenericExecutor\n",
        "from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext\n",
        "from tfx.proto import evaluator_pb2\n",
        "from tfx.proto import pusher_pb2\n",
        "from tfx.proto import trainer_pb2\n",
        "from tfx.utils.dsl_utils import external_input\n",
        "\n",
        "from tensorflow_metadata.proto.v0 import anomalies_pb2\n",
        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "from tensorflow_metadata.proto.v0 import statistics_pb2\n",
        "\n",
        "import tensorflow_model_analysis as tfma\n",
        "\n",
        "from tfx.components import ResolverNode\n",
        "from tfx.dsl.experimental import latest_blessed_model_resolver\n",
        "from tfx.types import Channel\n",
        "from tfx.types.standard_artifacts import Model\n",
        "from tfx.types.standard_artifacts import ModelBlessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JKo31y2L5hCy"
      },
      "source": [
        "Import the MLMD library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FNYHX6zA5gE5"
      },
      "outputs": [],
      "source": [
        "import ml_metadata as mlmd\n",
        "from ml_metadata.proto import metadata_store_pb2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## Download the dataset\n",
        "\n",
        "Download the [Iris dataset](https://archive.ics.uci.edu/ml/datasets/iris) to use in this tutorial. The dataset contains the length and width measurements of sepals and petals for 150 Iris flowers. You use this data to classify irises into one of three species: Iris setosa, Iris virginica, and Iris versicolor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B_NibNnjzGHu"
      },
      "outputs": [],
      "source": [
        "DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/iris/data/iris.csv'\n",
        "_data_root = tempfile.mkdtemp(prefix='tfx-data')\n",
        "_data_filepath = os.path.join(_data_root, \"iris.csv\")\n",
        "urllib.request.urlretrieve(DATA_PATH, _data_filepath)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8NXg2bGA19HJ"
      },
      "source": [
        "## Create an InteractiveContext\n",
        "\n",
        "To run TFX components interactively in this notebook, create an `InteractiveContext`. The `InteractiveContext` uses a temporary directory with an ephemeral MLMD database instance. Note that calls to `InteractiveContext` are no-ops outside a notebook environment such as Colab.\n",
        "\n",
        "In general, it is a good practice to group similar pipeline runs under a `Context`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bytrDFKh40mi"
      },
      "outputs": [],
      "source": [
        "interactive_context = InteractiveContext()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-58fa9S6Nao"
      },
      "source": [
        "## Construct the TFX Pipeline\n",
        "\n",
        "A TFX pipeline consists of several components that perform different aspects of the ML workflow. In this notebook, you create and run the `ExampleGen`, `StatisticsGen`, `SchemaGen`, and `Trainer` components, and then use the `Evaluator` and `Pusher` components to evaluate and push the trained model.\n",
        "\n",
        "Refer to the [components tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/components_keras) for more information on TFX pipeline components."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "urh3FTb81yyM"
      },
      "source": [
        "Note: Constructing a TFX pipeline by setting up the individual components involves a lot of boilerplate code. For the purpose of this tutorial, it is fine if you do not fully understand every line of code in the pipeline setup."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bnnq7Gf8CHZJ"
      },
      "source": [
        "### Instantiate and run the ExampleGen Component"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H9zaBZh3C_9x"
      },
      "outputs": [],
      "source": [
        "input_data = external_input(_data_root)\n",
        "example_gen = CsvExampleGen(input=input_data)\n",
        "\n",
        "# Run the ExampleGen component using the InteractiveContext\n",
        "interactive_context.run(example_gen)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nqxye_p1DLmf"
      },
      "source": [
        "### Instantiate and run the StatisticsGen Component"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s67sHU_vDRds"
      },
      "outputs": [],
      "source": [
        "statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])\n",
        "\n",
        "# Run the StatisticsGen component using the InteractiveContext\n",
        "interactive_context.run(statistics_gen)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xib9oRb_ExjJ"
      },
      "source": [
        "### Instantiate and run the SchemaGen Component"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "csmD4CSUE3JT"
      },
      "outputs": [],
      "source": [
        "infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'],\n",
        "                         infer_feature_shape = True)\n",
        "\n",
        "# Run the SchemaGen component using the InteractiveContext\n",
        "interactive_context.run(infer_schema)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_pYNlw7BHUjP"
      },
      "source": [
        "### Instantiate and run the Trainer Component\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MTxf8xs_kKfG"
      },
      "outputs": [],
      "source": [
        "# Define the module file for the Trainer component\n",
        "trainer_module_file = 'iris_trainer.py'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f3nLHEmUkRUw"
      },
      "outputs": [],
      "source": [
        "%%writefile {trainer_module_file}\n",
        "\n",
        "# Define the training algorithm for the Trainer module file\n",
        "import os\n",
        "from typing import List, Text\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "from tfx.components.trainer.executor import TrainerFnArgs\n",
        "from tfx.components.trainer.fn_args_utils import FnArgs\n",
        "\n",
        "# The iris dataset has 150 records, and is split into training and evaluation \n",
        "# datasets in a 2:1 split\n",
        "\n",
        "_TRAIN_DATA_SIZE = 100\n",
        "_EVAL_DATA_SIZE = 50\n",
        "_TRAIN_BATCH_SIZE = 100\n",
        "_EVAL_BATCH_SIZE = 50\n",
        "\n",
        "# Feature spec used to parse the examples: sepal length and width, petal\n",
        "# length and width (model inputs), and variety (species of flower, the label)\n",
        "\n",
        "_FEATURES = {\n",
        "    'sepal_length': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),\n",
        "    'sepal_width': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),\n",
        "    'petal_length': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),\n",
        "    'petal_width': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),\n",
        "    'variety': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0)\n",
        "}\n",
        "\n",
        "_LABEL_KEY = 'variety'\n",
        "\n",
        "_FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']\n",
        "\n",
        "def _gzip_reader_fn(filenames):\n",
        "  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')\n",
        "\n",
        "def _input_fn(file_pattern: List[Text],\n",
        "              batch_size: int = 200):\n",
        "  dataset = tf.data.experimental.make_batched_features_dataset(\n",
        "            file_pattern=file_pattern,\n",
        "            batch_size=batch_size,\n",
        "            features=_FEATURES,\n",
        "            reader=_gzip_reader_fn,\n",
        "            label_key=_LABEL_KEY)\n",
        "  \n",
        "  return dataset\n",
        "  \n",
        "def _build_keras_model():\n",
        "  inputs = [keras.layers.Input(shape = (1,), name = f) for f in _FEATURE_KEYS]\n",
        "  d = keras.layers.concatenate(inputs)\n",
        "  d = keras.layers.Dense(8, activation = 'relu')(d)\n",
        "  d = keras.layers.Dense(8, activation = 'relu')(d)\n",
        "  outputs = keras.layers.Dense(3, activation = 'softmax')(d)\n",
        "  model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "  model.compile(optimizer = 'adam',\n",
        "                loss = 'sparse_categorical_crossentropy',\n",
        "                metrics= [keras.metrics.SparseCategoricalAccuracy()])\n",
        "  return model\n",
        "\n",
        "def run_fn(fn_args: TrainerFnArgs):\n",
        "  train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)\n",
        "  eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)\n",
        "  \n",
        "  model = _build_keras_model()\n",
        "\n",
        "  # Use integer division so that Keras receives whole step and epoch counts\n",
        "  steps_per_epoch = _TRAIN_DATA_SIZE // _TRAIN_BATCH_SIZE\n",
        "\n",
        "  model.fit(train_dataset, \n",
        "            epochs=fn_args.train_steps // steps_per_epoch,\n",
        "            steps_per_epoch=steps_per_epoch,\n",
        "            validation_data=eval_dataset,\n",
        "            validation_steps=fn_args.eval_steps)\n",
        "  model.save(fn_args.serving_model_dir, save_format='tf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qcmSNiqq5QaV"
      },
      "source": [
        "Run the `Trainer` component."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4AzsMk7oflMg"
      },
      "outputs": [],
      "source": [
        "trainer = Trainer(\n",
        "    module_file=os.path.abspath(trainer_module_file),\n",
        "    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),\n",
        "    examples=example_gen.outputs['examples'],\n",
        "    schema=infer_schema.outputs['schema'],\n",
        "    train_args=trainer_pb2.TrainArgs(num_steps=100),\n",
        "    eval_args=trainer_pb2.EvalArgs(num_steps=50))\n",
        "\n",
        "interactive_context.run(trainer)"
      ]
    },
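    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The module file converts the step count passed through `TrainArgs` into Keras epochs. As a quick sanity check of that arithmetic, the sketch below repeats the calculation with the constants from the module file and the `num_steps=100` value used above. The variable names are local to this cell and are not part of TFX."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Constants copied from the trainer module file\n",
        "train_data_size = 100\n",
        "train_batch_size = 100\n",
        "# Value passed as TrainArgs(num_steps=100) in this tutorial\n",
        "train_steps = 100\n",
        "\n",
        "# One full pass over 100 training records in batches of 100 is a single step\n",
        "steps_per_epoch = train_data_size // train_batch_size\n",
        "# The requested number of training steps, expressed as whole epochs\n",
        "epochs = train_steps // steps_per_epoch\n",
        "\n",
        "print('steps_per_epoch:', steps_per_epoch)\n",
        "print('epochs:', epochs)"
      ]
    },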
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gdCq5c0f5MyA"
      },
      "source": [
        "### Evaluate and push the model\n",
        "\n",
        "Use the `Evaluator` component to evaluate and 'bless' the model before using the `Pusher` component to push the model to a serving directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NDx-fTUb6RUU"
      },
      "outputs": [],
      "source": [
        "_serving_model_dir = os.path.join(tempfile.mkdtemp(), 'serving_model/iris_classification')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PpS4-wCf6eLR"
      },
      "outputs": [],
      "source": [
        "eval_config = tfma.EvalConfig(\n",
        "    model_specs=[tfma.ModelSpec(label_key='variety')],\n",
        "    metrics_specs=[\n",
        "        tfma.MetricsSpec(metrics=[\n",
        "            tfma.MetricConfig(class_name='ExampleCount'),\n",
        "            # Iris is a three-class problem, so use a multi-class accuracy metric\n",
        "            tfma.MetricConfig(\n",
        "                class_name='SparseCategoricalAccuracy',\n",
        "                threshold=tfma.MetricThreshold(\n",
        "                    value_threshold=tfma.GenericValueThreshold(\n",
        "                        lower_bound={'value': 0.5}),\n",
        "                    change_threshold=tfma.GenericChangeThreshold(\n",
        "                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,\n",
        "                        absolute={'value': -1e-10})))\n",
        "        ])\n",
        "    ],\n",
        "    slicing_specs=[\n",
        "        tfma.SlicingSpec(),\n",
        "        tfma.SlicingSpec(feature_keys=['sepal_length'])\n",
        "    ])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kFuH1YTh8vSf"
      },
      "outputs": [],
      "source": [
        "model_resolver = ResolverNode(\n",
        "      instance_name='latest_blessed_model_resolver',\n",
        "      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,\n",
        "      model=Channel(type=Model),\n",
        "      model_blessing=Channel(type=ModelBlessing))\n",
        "interactive_context.run(model_resolver)\n",
        "\n",
        "evaluator = Evaluator(\n",
        "    examples=example_gen.outputs['examples'],\n",
        "    model=trainer.outputs['model'],\n",
        "    baseline_model=model_resolver.outputs['model'],\n",
        "    eval_config=eval_config)\n",
        "interactive_context.run(evaluator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NCV9gcCQ966W"
      },
      "outputs": [],
      "source": [
        "pusher = Pusher(\n",
        "    model=trainer.outputs['model'],\n",
        "    model_blessing=evaluator.outputs['blessing'],\n",
        "    push_destination=pusher_pb2.PushDestination(\n",
        "        filesystem=pusher_pb2.PushDestination.Filesystem(\n",
        "            base_directory=_serving_model_dir)))\n",
        "interactive_context.run(pusher)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9K7RzdBzkru7"
      },
      "source": [
        "Running the TFX pipeline populates the MLMD database. In the next section, you use the MLMD API to query this database for metadata."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6GRCGQu7RguC"
      },
      "source": [
        "## Query the MLMD Database\n",
        "\n",
        "The MLMD database stores three types of metadata:\n",
        "\n",
        "*    Metadata about the pipeline and lineage information associated with the pipeline components\n",
        "*    Metadata about artifacts that were generated during the pipeline run\n",
        "*    Metadata about the executions of the pipeline\n",
        "\n",
        "A typical production environment pipeline serves multiple models as new data arrives. When you encounter erroneous results in served models, you can query the MLMD database to isolate the erroneous models. You can then trace the lineage of the pipeline components that correspond to these models to debug them."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o0xVYqAkJybK"
      },
      "source": [
        "Set up the metadata (MD) store with the `InteractiveContext` defined previously to query the MLMD database."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P1p38etAv0kC"
      },
      "outputs": [],
      "source": [
        "store = mlmd.MetadataStore(interactive_context.metadata_connection_config)\n",
        "\n",
        "# All TFX artifacts are stored in the base directory\n",
        "base_dir = interactive_context.metadata_connection_config.sqlite.filename_uri.split('metadata.sqlite')[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uq-1ep4suvuZ"
      },
      "source": [
        "Create some helper functions to view the data from the MD store."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q1ib8yStu6CW"
      },
      "outputs": [],
      "source": [
        "def display_types(types):\n",
        "  # Helper function to render dataframes for the artifact and execution types\n",
        "  table = {'id': [], 'name': []}\n",
        "  for a_type in types:\n",
        "    table['id'].append(a_type.id)\n",
        "    table['name'].append(a_type.name)\n",
        "  return pd.DataFrame(data=table)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HmqzYZcV3UG5"
      },
      "outputs": [],
      "source": [
        "def display_artifacts(store, artifacts):\n",
        "  # Helper function to render dataframes for the input artifacts\n",
        "  table = {'artifact id': [], 'type': [], 'uri': []}\n",
        "  for a in artifacts:\n",
        "    table['artifact id'].append(a.id)\n",
        "    artifact_type = store.get_artifact_types_by_id([a.type_id])[0]\n",
        "    table['type'].append(artifact_type.name)\n",
        "    table['uri'].append(a.uri.replace(base_dir, './'))\n",
        "  return pd.DataFrame(data=table)"
      ]
    },
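    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`display_artifacts` shortens each artifact URI by replacing the `base_dir` prefix with `./`. The sketch below shows the same string handling on its own. The database path and artifact URI are hypothetical examples; the real values are generated by the `InteractiveContext`. The names `example_base_dir` and `short_uri` are used only here so that the `base_dir` variable defined earlier is left untouched."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical database path, for illustration only\n",
        "filename_uri = '/tmp/tfx-interactive-example/metadata.sqlite'\n",
        "\n",
        "# Everything before the database file name is the base directory\n",
        "example_base_dir = filename_uri.split('metadata.sqlite')[0]\n",
        "\n",
        "# Hypothetical artifact URI located under the base directory\n",
        "artifact_uri = example_base_dir + 'CsvExampleGen/examples/1'\n",
        "short_uri = artifact_uri.replace(example_base_dir, './')\n",
        "\n",
        "print(example_base_dir)\n",
        "print(short_uri)"
      ]
    },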
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iBdGCZ0CMJDO"
      },
      "outputs": [],
      "source": [
        "def display_properties(store, node):\n",
        "  # Helper function to render dataframes for artifact and execution properties\n",
        "  table = {'property': [], 'value': []}\n",
        "  for props in (node.properties, node.custom_properties):\n",
        "    for k, v in props.items():\n",
        "      table['property'].append(k)\n",
        "      # A Value holds at most one of int_value, double_value, or string_value\n",
        "      kind = v.WhichOneof('value')\n",
        "      table['value'].append(getattr(v, kind) if kind else None)\n",
        "  return pd.DataFrame(data=table)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1B-jRNH0M0k4"
      },
      "source": [
        "First, query the MD store for a list of all its stored `ArtifactTypes`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6zXSQL8s5dyL"
      },
      "outputs": [],
      "source": [
        "display_types(store.get_artifact_types())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "quOsBgtM3r7S"
      },
      "source": [
        "Next, query all `PushedModel` artifacts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bUv_EI-bEMMu"
      },
      "outputs": [],
      "source": [
        "pushed_models = store.get_artifacts_by_type(\"PushedModel\")\n",
        "display_artifacts(store, pushed_models)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UecjkVOqJCBE"
      },
      "source": [
        "Query the MD store for the latest pushed model. This tutorial has only one pushed model. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8tPvRtcPTrU"
      },
      "outputs": [],
      "source": [
        "pushed_model = pushed_models[-1]\n",
        "display_properties(store, pushed_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f5Mz4vfP6wHO"
      },
      "source": [
        "One of the first steps in debugging a pushed model is to find out which trained model was pushed and which training data was used to train that model.\n",
        "\n",
        "MLMD provides traversal APIs to walk through the provenance graph, which you can use to analyze the model provenance. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BLfydQVxOwf3"
      },
      "outputs": [],
      "source": [
        "def get_one_hop_parent_artifacts(store, artifacts):\n",
        "  # Get a list of artifacts within a 1-hop neighborhood of the artifacts of interest\n",
        "  artifact_ids = [artifact.id for artifact in artifacts]\n",
        "  executions_ids = set(\n",
        "      event.execution_id\n",
        "      for event in store.get_events_by_artifact_ids(artifact_ids)\n",
        "      if event.type == metadata_store_pb2.Event.OUTPUT)\n",
        "  artifacts_ids = set(\n",
        "      event.artifact_id\n",
        "      for event in store.get_events_by_execution_ids(executions_ids)\n",
        "      if event.type == metadata_store_pb2.Event.INPUT) \n",
        "  return [artifact for artifact in store.get_artifacts_by_id(artifacts_ids)]"
      ]
    },
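    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why filtering first on `OUTPUT` events and then on `INPUT` events yields the parents of an artifact, the following library-free sketch repeats the same two steps on a hand-written event list. The tuple layout, the ids, and the `toy_one_hop_parents` name are made up for illustration and are not MLMD's data model or API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy events as (artifact_id, execution_id, kind); all ids are hypothetical\n",
        "toy_events = [\n",
        "    (1, 10, 'OUTPUT'),  # examples produced by execution 10\n",
        "    (1, 11, 'INPUT'),   # examples consumed by execution 11\n",
        "    (2, 11, 'OUTPUT'),  # model produced by execution 11\n",
        "    (2, 12, 'INPUT'),   # model consumed by execution 12\n",
        "    (3, 12, 'OUTPUT'),  # pushed model produced by execution 12\n",
        "]\n",
        "\n",
        "def toy_one_hop_parents(artifact_ids):\n",
        "  # Step 1: executions that produced the artifacts of interest\n",
        "  producers = {e for a, e, kind in toy_events\n",
        "               if a in artifact_ids and kind == 'OUTPUT'}\n",
        "  # Step 2: artifacts that those executions consumed\n",
        "  return {a for a, e, kind in toy_events\n",
        "          if e in producers and kind == 'INPUT'}\n",
        "\n",
        "print(toy_one_hop_parents({3}))\n",
        "print(toy_one_hop_parents({2}))"
      ]
    },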
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3G0e0WIE9e9w"
      },
      "source": [
        "Query the parent artifacts for the pushed model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pOEFxucJQ1i6"
      },
      "outputs": [],
      "source": [
        "parent_artifacts = get_one_hop_parent_artifacts(store, [pushed_model])\n",
        "display_artifacts(store, parent_artifacts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pJror5mf-W0M"
      },
      "source": [
        "Query the properties for the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OSCb0bg6Qmj4"
      },
      "outputs": [],
      "source": [
        "exported_model = parent_artifacts[0]\n",
        "display_properties(store, exported_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "phz1hfzc_UcK"
      },
      "source": [
        "Query the upstream artifacts for the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nx_-IVhjRGA4"
      },
      "outputs": [],
      "source": [
        "model_parents = get_one_hop_parent_artifacts(store, [exported_model])\n",
        "display_artifacts(store, model_parents)"
      ]
    },
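    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cells above walk the lineage one hop at a time. As a rough sketch, the same step can be repeated until no new parents appear, which collects every upstream artifact. The example below does this on a hand-written event list; the ids, the tuple layout, and the `parents_of` and `all_ancestors` names are hypothetical and stand in for calls to the MD store."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy events as (artifact_id, execution_id, kind); all ids are hypothetical\n",
        "lineage_events = [\n",
        "    (1, 10, 'OUTPUT'),  # examples produced by execution 10\n",
        "    (1, 13, 'INPUT'),   # examples consumed by execution 13\n",
        "    (4, 13, 'OUTPUT'),  # schema produced by execution 13\n",
        "    (1, 11, 'INPUT'),   # examples consumed by execution 11\n",
        "    (4, 11, 'INPUT'),   # schema consumed by execution 11\n",
        "    (2, 11, 'OUTPUT'),  # model produced by execution 11\n",
        "    (2, 12, 'INPUT'),   # model consumed by execution 12\n",
        "    (3, 12, 'OUTPUT'),  # pushed model produced by execution 12\n",
        "]\n",
        "\n",
        "def parents_of(artifact_ids):\n",
        "  # Inputs of the executions that produced the given artifacts\n",
        "  producers = {e for a, e, kind in lineage_events\n",
        "               if a in artifact_ids and kind == 'OUTPUT'}\n",
        "  return {a for a, e, kind in lineage_events\n",
        "          if e in producers and kind == 'INPUT'}\n",
        "\n",
        "def all_ancestors(artifact_ids):\n",
        "  # Repeat the one-hop step until it finds no unseen artifacts\n",
        "  seen = set()\n",
        "  frontier = set(artifact_ids)\n",
        "  while frontier:\n",
        "    frontier = parents_of(frontier) - seen\n",
        "    seen |= frontier\n",
        "  return seen\n",
        "\n",
        "print(sorted(all_ancestors({3})))"
      ]
    },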
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00jqfk6o_niu"
      },
      "source": [
        "Get the training data that the model was trained with."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2nMECsKvROEX"
      },
      "outputs": [],
      "source": [
        "used_data = model_parents[0]\n",
        "display_properties(store, used_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GgTMTaew_3Fe"
      },
      "source": [
        "Now that you have the training data that the model trained with, query the database again to find the training step (execution). Query the MD store for a list of the registered execution types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8cBKQsScaD9a"
      },
      "outputs": [],
      "source": [
        "display_types(store.get_execution_types())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wxcue6SggQ_b"
      },
      "source": [
        "The training step is the `ExecutionType` named `tfx.components.trainer.component.Trainer`. Traverse the MD store to get the trainer run that corresponds to the pushed model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ned8BxHzaunk"
      },
      "outputs": [],
      "source": [
        "def find_producer_execution(store, artifact):\n",
        "  # Find the execution that produced the given artifact\n",
        "  execution_ids = set(\n",
        "      event.execution_id\n",
        "      for event in store.get_events_by_artifact_ids([artifact.id])\n",
        "      if event.type == metadata_store_pb2.Event.OUTPUT)\n",
        "  return store.get_executions_by_id(execution_ids)[0]\n",
        "\n",
        "# Use a new name so the `trainer` component defined earlier is not overwritten\n",
        "trainer_execution = find_producer_execution(store, exported_model)\n",
        "display_properties(store, trainer_execution)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CYzlTckHClxC"
      },
      "source": [
        "## Summary\n",
        "\n",
        "In this tutorial, you learned how to use MLMD to trace the lineage of your TFX pipeline components and resolve issues.\n",
        "\n",
        "To learn more about how to use MLMD, check out these additional resources:\n",
        "\n",
        "* [MLMD API documentation](https://www.tensorflow.org/tfx/ml_metadata/api_docs/python/mlmd)\n",
        "* [MLMD guide](https://www.tensorflow.org/tfx/guide/mlmd)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "mlmd_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
